using LinearAlgebra, Distributions, StatsBase, Plots,MLDatasets
using Statistics
using Nemo, Images
using ToeplitzMatrices
using Random
using Kronecker
using StatPlots
using SpecialFunctions: erfc
"""
    generate_statistic(k, m, p)

Draw random first- and second-order statistics for `k` tasks with `m` classes each
in dimension `p`.

Class `j` of task `i` is stored at index `m*(i-1)+j`.

# Returns
- `M`: `p × (k*m)` matrix; every column is `0.08` times a standard normal draw.
- `Σ`: `p × p × (k*m)` array; every slice is the symmetric Toeplitz matrix whose first
  row and column are `α.^(0:p-1)`, with one `α` drawn uniformly per class.
"""
function generate_statistic(k, m, p)
    nclasses = k * m

    # Random class means, drawn column after column.
    M = zeros(p, nclasses)
    for col in 1:nclasses
        M[:, col] = 0.08 * randn(p, 1)
    end

    # One decay rate per (task, class) pair; `rand` is already non-negative, `abs.` is
    # kept so the values are exactly those of the previous implementation.
    α = abs.(rand(nclasses, 1))

    Σ = zeros(p, p, nclasses)
    for col in 1:nclasses
        decay = α[col] .^ (0:p-1)
        Σ[:, :, col] = Toeplitz(decay, decay)
    end
    return M, Σ
end

"""
    generate_data(p, ns, k, m, β, M, Σ, data_type, n_test, S, T)

Sample synthetic training and test sets for a source task (task 1) and a target task
(task `k`).

Every sample of class `j` of a task is `M[:, c] + Σ[:, :, c] * g` with `g` a standard
normal vector of length `p` and `c = m*(task-1)+j`. Note that `Σ[:, :, c]` multiplies
the noise directly, i.e. it acts as a mixing matrix.

# Arguments
- `p`: feature dimension.
- `ns`: number of training samples per class, indexed by `m*(task-1)+j`.
- `k`, `m`: number of tasks and of classes per task. Labels are built from the first
  two classes only (`+1` then `-1`), so `m` is expected to be 2.
- `M`, `Σ`: class means (`p × k*m`) and mixing matrices (`p × p × k*m`).
- `n_test`: number of test samples per class, indexed like `ns`.
- `β`, `data_type`, `S`, `T`: accepted for interface compatibility; not used.

# Returns
`(Sfts, Slabels, Tfts, Tlabels, X_test, y_test, X_test1, y_test1)`: features are
`p × n` matrices with the classes stored one after the other; labels are `n × 1`
matrices of `±1`. `X_test`/`y_test` belong to task 1 and `X_test1`/`y_test1` to task `k`.

The target test set used to be drawn from task 2 regardless of `k`, whereas the target
training set came from task `k`; both now use task `k` (identical when `k == 2`).
"""
function generate_data(p, ns, k, m, β, M, Σ, data_type, n_test, S, T)
    # Index of class `j` of task `task` in `M`, `Σ`, `ns` and `n_test`.
    column(task, j) = m * (task - 1) + j

    # `count` samples of the class stored at index `col`, one per column.
    sample_class(col, count) = M[:, col] .+ Σ[:, :, col] * randn(p, count)

    # All classes of one task, concatenated class after class.
    sample_task(task, counts) =
        reduce(hcat, [sample_class(column(task, j), counts[column(task, j)]) for j in 1:m])

    # +1 for the first class, -1 for the second one.
    binary_labels(task, counts) =
        [ones(counts[column(task, 1)], 1); -ones(counts[column(task, 2)], 1)]

    source = 1
    target = k

    # The sampling order (source train, target train, source test, target test) is
    # kept so that a seeded run with k == 2 reproduces the previous draws.
    Sfts = sample_task(source, ns)
    Slabels = binary_labels(source, ns)
    Tfts = sample_task(target, ns)
    Tlabels = binary_labels(target, ns)
    X_test = sample_task(source, n_test)
    y_test = binary_labels(source, n_test)
    X_test1 = sample_task(target, n_test)
    y_test1 = binary_labels(target, n_test)
    return Sfts, Slabels, Tfts, Tlabels, X_test, y_test, X_test1, y_test1
end
"""
    RMTMTLLSSVM(Sfts, Slabels, Tfts, Tlabels, λ, γ, X_test, y_test, X_test1, y_test1, ns)

Fit a two-task (source/target), two-class least-squares SVM on centered and rescaled
data, score the test sets, and compute large-dimensional predictions of the class-wise
mean and variance of the scores together with the resulting error rates.

# Arguments
- `Sfts`, `Tfts`: `p × n1` source and `p × n2` target training features; the columns of
  the first class come first (`ns[1]` resp. `ns[3]` columns), then the second class.
- `Slabels`, `Tlabels`: matching `±1` labels, stacked into the LS-SVM target vector.
- `λ`, `γ`: the task-coupling matrix is `diagm(γ) + λ * ones(k, k)`; `γ` has one entry
  per task.
- `X_test`, `y_test`: source test features (`p × n`) and `±1` labels, `+1` samples first.
- `X_test1`, `y_test1`: same for the target task.
- `ns`: training sample counts ordered (source +1, source -1, target +1, target -1).

# Returns
`(gx_s, gx_t, μ_th, μ_emp, σ_th, σ_emp, error_source_emp, error_target_emp,
error_source_th, error_target_th)`: test scores for both tasks, predicted and empirical
class-wise means, predicted and empirical class-wise variances, and the empirical and
predicted misclassification rates.

# Changes with respect to the previous implementation
- Scores above the empirical threshold are now labelled `+1` (both sides used to be
  set to `-1`).
- The target threshold uses the target scores for both classes (the second mean used
  to be taken from `gx_s`).
- Empirical errors compare predictions and labels element-wise (`.!=`); `!=` compared
  the two arrays as a whole.
- The number of classes (`m = 2`) and the per-class test counts are derived locally
  instead of being read from the globals `m` and `n_test`.
- `A^(1/2)` and other loop-invariant products are computed once; unused intermediates
  were dropped.
"""
function RMTMTLLSSVM(Sfts,Slabels,Tfts,Tlabels,λ,γ,X_test,y_test,X_test1,y_test1,ns)
    # ---- Problem sizes ------------------------------------------------------------
    p = size(Tfts, 1)
    k = 2   # tasks: source and target
    m = 2   # classes per task (labels are ±1)
    # Per-class test counts, ordered like `ns`.
    n_test = [count(y_test .== 1), count(y_test .== -1),
              count(y_test1 .== 1), count(y_test1 .== -1)]
    n = sum(ns)
    c = ns / n
    co = k * p / n
    n1 = sum(ns[1:2]); n2 = sum(ns[3:4])

    # Task-membership indicator of every training sample.
    P = zeros(n, k)
    P[1:n1, 1] .= 1
    P[n1+1:end, 2] .= 1

    # Task-coupling matrix and its lift to the stacked (k*p)-dimensional feature space.
    W = diagm(γ) + λ * ones(k, 1) * ones(1, k)
    A = kron(W, Matrix{Float64}(I, p, p))
    sqrtW = W^(1/2)
    sqrtA = A^(1/2)

    # ---- Centering and scaling ----------------------------------------------------
    Mi, _ = compute_statistics(Sfts, Tfts, ns)   # class means of the raw data
    Sfts_c = Sfts * (Matrix{Float64}(I, n1, n1) - (1/n1) * ones(n1, 1) * ones(1, n1))
    scale_s = (1/(n1*p)) * tr(Sfts_c * Sfts_c')
    Sfts_cs = Sfts_c ./ scale_s
    Tfts_c = Tfts * (Matrix{Float64}(I, n2, n2) - (1/n2) * ones(n2, 1) * ones(1, n2))
    scale_t = (1/(n2*p)) * tr(Tfts_c * Tfts_c')
    Tfts_cs = Tfts_c ./ scale_t
    M, Σ = compute_statistics(Sfts_cs, Tfts_cs, ns)   # statistics of the processed data

    # Test data receive the training centering (weighted class means) and scaling.
    M_cens = sum((ns[j] ./ n1) * Mi[:, j] for j = 1:m)
    X_test = (X_test .- M_cens) ./ scale_s
    M_cent = sum((ns[m*(k-1)+j] ./ n2) * Mi[:, m*(k-1)+j] for j = 1:m)
    X_test1 = (X_test1 .- M_cent) ./ scale_t

    # ---- LS-SVM solution and test scores ------------------------------------------
    Z = [Sfts_cs zeros(p, n2); zeros(p, n1) Tfts_cs]
    Q = inv((1/(k*p)) * Z' * A * Z + Matrix{Float64}(I, n, n))
    y = [Slabels; Tlabels]
    b = (P' * Q * P) \ (P' * Q * y)
    e_s = zeros(k, 1); e_t = zeros(k, 1)
    e_s[1] = 1; e_t[2] = 1
    # Shared factor of both score vectors, evaluated right-to-left so that only
    # matrix-vector products are needed.
    w = A * (Z * (Q * (y - P * b)))
    gx_s = (1/(k*p)) * kron(e_s, X_test)' * w .+ b[1]
    gx_t = (1/(k*p)) * kron(e_t, X_test1)' * w .+ b[2]

    # ---- Empirical class-wise mean and variance of the scores ---------------------
    μ_emp = zeros(m*k, 1)
    σ_emp = zeros(m*k, 1)
    σ_th = zeros(m*k, 1)
    μ_emp[1] = mean(gx_s[1:n_test[1]])
    μ_emp[2] = mean(gx_s[1+n_test[1]:n_test[1]+n_test[2]])
    μ_emp[3] = mean(gx_t[1:n_test[3]])
    μ_emp[4] = mean(gx_t[n_test[3]+1:n_test[3]+n_test[4]])
    σ_emp[1] = std(gx_s[1:n_test[1]]).^2
    σ_emp[2] = std(gx_s[1+n_test[1]:n_test[1]+n_test[2]]).^2
    σ_emp[3] = std(gx_t[1:n_test[3]]).^2
    σ_emp[4] = std(gx_t[n_test[3]+1:end]).^2

    # ---- Large-dimensional predictions --------------------------------------------
    # Class means embedded in the rows of their own task, and generalized covariances.
    M0 = zeros(k*p, m*k)
    C = zeros(k*p, k*p, m*k)
    for i = 1:k
        e_i = zeros(k, 1); e_i[i] = 1
        for j = 1:m
            idx = m*(i-1) + j
            M0[:, idx] = kron(e_i, M[:, idx])
            C[:, :, idx] = sqrtA * kron(e_i * e_i', Σ[:, :, idx] + M[:, idx] * M[:, idx]') * sqrtA
        end
    end

    # Fixed point δ and the associated deterministic equivalent Qbar.
    δ, Qbar = delta_function(C, k, m, p, ns)
    δ_tilde = c ./ (co * (ones(m*k, 1) + δ))
    Mg = sqrtA * M0 * diagm(sqrt.(vec(δ_tilde)))

    invQtilde0 = zeros(k*p, k*p)
    for i = 1:k
        e_i = zeros(k, 1); e_i[i] = 1
        for j = 1:m
            idx = m*(i-1) + j
            invQtilde0 = invQtilde0 + kron(sqrtW * e_i * e_i' * sqrtW, δ_tilde[idx] * Σ[:, :, idx])
        end
    end
    Qtilde0 = inv(invQtilde0 + Matrix{Float64}(I, k*p, k*p))
    Γ = inv(Matrix{Float64}(I, m*k, m*k) + Mg' * Qtilde0 * Mg)

    # Labels and bias-corrected labels per (task, class).
    Ygotique = [1; -1; 1; -1]
    yo = [1-b[1]; -1-b[1]; 1-b[2]; -1-b[2]]
    Dhalf = diagm(vec(δ_tilde))^(1/2)
    μ_th = Ygotique - diagm(vec(δ_tilde))^(-1/2) * Γ * Dhalf * yo

    S1 = zeros(k*p, k*p, m*k)
    d = zeros(k*m, 1)
    Tbar = zeros(m*k, m*k)
    Tg = zeros(m*k, m*k)
    for i = 1:k
        e_i = zeros(k, 1); e_i[i] = 1
        for j = 1:m
            idx = m*(i-1) + j
            S1[:, :, idx] = kron(e_i * e_i', Σ[:, :, idx])
            d[idx] = ns[idx] ./ (k*p*(1+δ[idx])^2)
            # Does not depend on the inner index, so it is formed once per class.
            QSQ = Qbar * sqrtA * S1[:, :, idx] * sqrtA * Qbar
            for idxp = 1:m*k
                Tbar[idx, idxp] = (1/(k*p)) * tr(C[:, :, idxp] * QSQ)
                Tg[idx, idxp] = (1/(k*p)) * tr(C[:, :, idx] * Qbar * C[:, :, idxp] * Qbar)
            end
        end
    end
    D = diagm(vec(d))
    Tmat = Tbar / (Matrix{Float64}(I, m*k, m*k) - D * Tg)

    κ = zeros(k*m, k*m)
    V = zeros(k*p, k*p, m*k)
    for idx = 1:m*k
        for idxp = 1:m*k
            κ[idx, idxp] = d[idxp] * Tmat[idx, idxp] / δ_tilde[idxp]
        end
        ASA = sqrtA * S1[:, :, idx] * sqrtA
        # NOTE(review): as in the previous implementation every summand uses the slice
        # of class `idx`, not of the summation index `idxp` — confirm against the
        # derivation.
        Sec = sum(δ_tilde[idxp] * κ[idx, idxp] * ASA for idxp = 1:m*k)
        V[:, :, idx] = ASA + Sec
        σ_th[idx] = yo' * Dhalf * (Γ * diagm(vec(κ[idx, :])) * Γ + Γ * Mg' * Qtilde0 * V[:, :, idx] * Qtilde0 * Mg * Γ) * Dhalf * yo
    end

    # ---- Empirical and predicted error rates --------------------------------------
    nt_s = sum(n_test[1:2]); nt_t = sum(n_test[3:4])
    pred_source = zeros(nt_s, 1); pred_target = zeros(nt_t, 1)
    # Decision threshold: midpoint between the two empirical class means.
    threshold_emp = (mean(gx_s[1:n_test[1]]) + mean(gx_s[1+n_test[1]:end])) / 2
    pred_source[gx_s .< threshold_emp] .= -1
    pred_source[gx_s .> threshold_emp] .= 1
    threshold_emp_target = (mean(gx_t[1:n_test[3]]) + mean(gx_t[1+n_test[3]:end])) / 2
    pred_target[gx_t .< threshold_emp_target] .= -1
    pred_target[gx_t .> threshold_emp_target] .= 1
    error_source_emp = sum(pred_source .!= y_test) ./ nt_s
    error_target_emp = sum(pred_target .!= y_test1) ./ nt_t
    # NOTE(review): the erfc argument uses the full gap μ_th[1]-μ_th[2]; with a midpoint
    # threshold one would expect half of it — kept unchanged, confirm against the paper.
    error_source_th = (n_test[1]/nt_s) * 0.5 * erfc((μ_th[1]-μ_th[2]) ./ (sqrt(2)*sqrt(σ_th[1]))) + (n_test[2]/nt_s) * 0.5 * erfc((μ_th[1]-μ_th[2]) ./ (sqrt(2)*sqrt(σ_th[2])))
    error_target_th = (n_test[3]/nt_t) * 0.5 * erfc((μ_th[3]-μ_th[4]) ./ (sqrt(2)*sqrt(σ_th[3]))) + (n_test[4]/nt_t) * 0.5 * erfc((μ_th[3]-μ_th[4]) ./ (sqrt(2)*sqrt(σ_th[4])))
    # TODO: optimize with respect to y.
    return gx_s,gx_t,μ_th,μ_emp,σ_th,σ_emp,error_source_emp,error_target_emp,error_source_th,error_target_th
end

"""
    compute_statistics(Sfts, Tfts, ns)

Compute the empirical mean and (biased, divided by the class size) covariance of the
four classes stored in `Sfts` and `Tfts`.

# Arguments
- `Sfts`: `p × (ns[1]+ns[2])` matrix, first-class samples in the leading `ns[1]` columns.
- `Tfts`: `p × (ns[3]+ns[4])` matrix, first-class samples in the leading `ns[3]` columns.
- `ns`: the four class sizes.

# Returns
- `M`: `p × 4` matrix of class means.
- `Σ`: `p × p × 4` array of class covariances.

# Throws
- `DimensionMismatch` if the row counts differ or the column counts do not match `ns`.

The feature dimension is taken from `Sfts`; the function used to read a global `p`.
"""
function compute_statistics(Sfts, Tfts, ns)
    p = size(Sfts, 1)
    size(Tfts, 1) == p ||
        throw(DimensionMismatch("Sfts has $p rows but Tfts has $(size(Tfts, 1))"))
    size(Sfts, 2) == ns[1] + ns[2] ||
        throw(DimensionMismatch("Sfts must have ns[1]+ns[2] = $(ns[1] + ns[2]) columns"))
    size(Tfts, 2) == ns[3] + ns[4] ||
        throw(DimensionMismatch("Tfts must have ns[3]+ns[4] = $(ns[3] + ns[4]) columns"))

    # Samples of each class, in the order (source 1, source 2, target 1, target 2).
    classes = (
        Sfts[:, 1:ns[1]],
        Sfts[:, ns[1]+1:end],
        Tfts[:, 1:ns[3]],
        Tfts[:, ns[3]+1:end],
    )

    M = zeros(p, 4)
    Σ = zeros(p, p, 4)
    for (cls, X) in enumerate(classes)
        μ = mean(X, dims=2)
        centered = X .- μ
        M[:, cls] = μ
        Σ[:, :, cls] = centered * centered' ./ ns[cls]
    end
    return M, Σ
end
"""
    delta_function(C, k, m, p, ns; tol=1e-6, maxiter=100_000)

Solve, by fixed-point iteration from a random start, the system

    δ[a] = tr(C[:, :, a] * Qbar) / (k*p),
    Qbar = inv(sum((c[a]/co) * C[:, :, a] / (1 + δ[a]) for a in 1:m*k) + I)

with `c = ns / sum(ns)` and `co = k*p / sum(ns)`.

# Arguments
- `C`: `(k*p) × (k*p) × (m*k)` array, one matrix per class.
- `k`, `m`, `p`: number of tasks, classes per task, and feature dimension.
- `ns`: number of samples per class (length `m*k`).

# Keywords
- `tol`: the iteration stops once no entry of `δ` changes by more than `tol`.
- `maxiter`: upper bound on the number of iterations; a warning is logged if it is hit.

# Returns
- `δ`: `(m*k) × 1` matrix holding the fixed point.
- `Qbar`: the `(k*p) × (k*p)` matrix defined above, evaluated at the returned `δ`.

The convergence test covers all `m*k` entries (it used to be hard-coded to the first
four), the matrix inverse is computed once per iteration, and per-iteration
diagnostics go through `@debug` instead of being printed.
"""
function delta_function(C, k, m, p, ns; tol=1e-6, maxiter=100_000)
    total = sum(ns)
    c = ns / total
    co = k * p / total
    nclass = m * k
    dim = k * p

    # Inverse of Qbar for a given δ.
    inverse_qbar(δ) =
        sum((c[a] / co) * C[:, :, a] ./ (1 + δ[a]) for a in 1:nclass) +
        Matrix{Float64}(I, dim, dim)

    δ_prev = rand(nclass, 1)
    δ_next = rand(nclass, 1)
    n_iter = 0
    while maximum(abs, δ_prev - δ_next) > tol
        if n_iter >= maxiter
            @warn "delta_function stopped before reaching the tolerance" maxiter tol
            break
        end
        n_iter += 1
        δ_prev = δ_next
        Q_iter = inv(inverse_qbar(δ_prev))   # shared by all classes in this sweep
        δ_next = zeros(nclass, 1)
        for a in 1:nclass
            δ_next[a] = (1 / dim) * tr(C[:, :, a] * Q_iter)
        end
        @debug "delta_function iteration" n_iter change = δ_prev - δ_next
    end

    Qbar = inv(inverse_qbar(δ_next))
    return δ_next, Qbar
end

#using PGFPlotsX
#Define the data generation process
# Feature dimension.
p = 200
# Training samples per class: fixed multiples of p, rounded down.
ns = floor.(Int, [2.5, 2.3, 1.7, 1.4] * p)
# Number of tasks, classes per task, and the β value forwarded to generate_data.
k = 2
m = 2
β = 0.5
# 1000 test samples for every (task, class) pair, stored as an (m*k) × 1 matrix.
n_test = fill(1000, m * k, 1)
# Random class means and Toeplitz matrices of the synthetic model.
M, Σ = generate_statistic(k, m, p)
data_type = "synthetic"
"""
    S(fts, labels)

Untyped pair of features and labels (presumably for the source task — the visible
script only passes the type itself to `generate_data`, which does not use it).
"""
struct S
    fts::Any
    labels::Any
end
"""
    T(fts, labels)

Untyped pair of features and labels (presumably for the target task — the visible
script only passes the type itself to `generate_data`, which does not use it).
"""
struct T
    fts::Any
    labels::Any
end
# Hyperparameters: one γ entry per task and the coupling weight λ.
γ = [1, 1]
λ = 1

# Draw the synthetic training and test sets of both tasks.
(Sfts, Slabels, Tfts, Tlabels, X_test, y_test, X_test1, y_test1) =
    generate_data(p, ns, k, m, β, M, Σ, data_type, n_test, S, T)

# Fit the classifier and collect empirical and theoretical statistics of its scores.
(gx_s, gx_t, μ_th, μ_emp, σ_th, σ_emp,
 error_source_emp, error_target_emp, error_source_th, error_target_th) =
    RMTMTLLSSVM(Sfts, Slabels, Tfts, Tlabels, λ, γ, X_test, y_test, X_test1, y_test1, ns)
#PyPlot.close()   # close all plot windows

# fig,axes = subplots(1,2)   # create figure and axes
# ax = axes[1]     # operate on subplot 1
# ax[:plot](Normal(μ_th[1],σ_th[1]),color="red")
# ax[:plot](Normal(μ_th[2],σ_th[2]),color="green")
# ax = axes[2]      # operate on subplot 2
# ax[:plot](Normal(μ_th[3],σ_th[3]),color="yellow")
# ax[:plot](Normal(μ_th[4],σ_th[4]),color="black")
# fig[:canvas][:set_window_title]("title of window")

# fig[:canvas][:draw]()   # update the figure to draw if you're doing this in the REPL
# plot(layout = 4)
#normalize = true;
#p=plt.hist(bins[:-1], bins, weights=counts)
#p=normalize(histogram(gx_s[1:n_test[1]]));
# Source task: normalized histograms of the test scores of both classes, overlaid with
# the Gaussian densities given by the theoretical means and variances.
p1 = histogram(gx_s[n_test[1]+1:end]; normalize=true, alpha=0.5, bins=100)
histogram!(p1, gx_s[1:n_test[1]]; normalize=true, alpha=0.5, bins=100)
for cls in 1:2
    plot!(p1, Normal(μ_th[cls], sqrt(σ_th[cls])))
end
display(p1)

# Target task: same figure built from the target scores and classes 3 and 4.
p2 = histogram(gx_t[n_test[3]+1:end]; normalize=true, alpha=0.5, bins=100)
histogram!(p2, gx_t[1:n_test[3]]; normalize=true, alpha=0.5, bins=100)
for cls in 3:4
    plot!(p2, Normal(μ_th[cls], sqrt(σ_th[cls])))
end
display(p2)

# Theoretical versus empirical means, then theoretical versus empirical variances.
foreach(println, (μ_th, μ_emp, σ_th, σ_emp))
# plot(Normal(μ_th[1],σ_th[1]));

#Define function generate_statistic
